{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset,load_metric\n",
    "from transformers import AutoTokenizer, AutoModelForMultipleChoice, TrainingArguments, Trainer, set_seed\n",
    "from dataclasses import dataclass\n",
    "from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy\n",
    "from typing import Optional, Union\n",
    "import torch\n",
    "import numpy as np\n",
    "from huggingface_hub import notebook_login\n",
    "from datasets import set_caching_enabled\n",
    "set_caching_enabled(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set seed "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model and tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_checkpoint = \"michiyasunaga/BioLinkBERT-base\"\n",
    "model_checkpoint = 'bert-base-uncased'\n",
    "\n",
    "model = AutoModelForMultipleChoice.from_pretrained(model_checkpoint)\n",
    "# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=512) # truncation of long questions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load preprocessed QA dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# preprocessed questions from here: https://github.com/michiyasunaga/LinkBERT\n",
    "\n",
    "datafiles = {}\n",
    "datafiles['train'] = 'preprocessed_questions/train.json'\n",
    "datafiles['test'] = 'preprocessed_questions/test.json'\n",
    "datafiles['validation'] = 'preprocessed_questions/dev.json'\n",
    "\n",
    "qa_dataset = load_dataset('json', data_files=datafiles)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tokenize the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ending_names = [\"ending0\", \"ending1\", \"ending2\", \"ending3\"]\n",
    "\n",
    "\n",
    "def preprocess_function(examples):\n",
    "    first_sentences = [[context] * 4 for context in examples[\"sent1\"]]\n",
    "    question_headers = examples[\"sent2\"]\n",
    "    second_sentences = [\n",
    "        [f\"{header} {examples[end][i]}\" for end in ending_names] for i, header in enumerate(question_headers)\n",
    "    ]\n",
    "\n",
    "    first_sentences = sum(first_sentences, [])\n",
    "    second_sentences = sum(second_sentences, [])\n",
    "\n",
    "    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)\n",
    "    return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}\n",
    "\n",
    "tokenized_qa = qa_dataset.map(preprocess_function, batched=True)\n",
    "tokenized_qa\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data Collator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class DataCollatorForMultipleChoice:\n",
    "    \"\"\"\n",
    "    Data collator that will dynamically pad the inputs for multiple choice received.\n",
    "    \"\"\"\n",
    "\n",
    "    tokenizer: PreTrainedTokenizerBase\n",
    "    padding: Union[bool, str, PaddingStrategy] = True\n",
    "    max_length: Optional[int] = None\n",
    "    pad_to_multiple_of: Optional[int] = None\n",
    "\n",
    "    def __call__(self, features):\n",
    "        label_name = \"label\" if \"label\" in features[0].keys() else \"labels\"\n",
    "        labels = [feature.pop(label_name) for feature in features]\n",
    "        batch_size = len(features)\n",
    "        num_choices = len(features[0][\"input_ids\"])\n",
    "        flattened_features = [\n",
    "            [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features\n",
    "        ]\n",
    "        flattened_features = sum(flattened_features, [])\n",
    "\n",
    "        batch = self.tokenizer.pad(\n",
    "            flattened_features,\n",
    "            padding=self.padding,\n",
    "            max_length=self.max_length,\n",
    "            pad_to_multiple_of=self.pad_to_multiple_of,\n",
    "            return_tensors=\"pt\",\n",
    "        )\n",
    "\n",
    "        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}\n",
    "        batch[\"labels\"] = torch.tensor(labels, dtype=torch.int64)\n",
    "        return batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Log into the hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# need to log into huggingface\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute metrics for accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "metric = load_metric('accuracy')\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    predictions, labels = eval_pred\n",
    "    predictions = np.argmax(predictions, axis=1)\n",
    "    return metric.compute(predictions=predictions, references=labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training arguments and trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_args = TrainingArguments(\n",
    "    # output_dir='medqa_fine_tuned_linkbert',\n",
    "    output_dir='medqa_fine_tuned_generic_bert',\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    learning_rate=1e-5,\n",
    "    per_device_train_batch_size=4,\n",
    "    per_device_eval_batch_size=4,\n",
    "    num_train_epochs=5,\n",
    "    weight_decay=0.01,\n",
    "    push_to_hub=True,\n",
    "    save_total_limit = 1,\n",
    "    resume_from_checkpoint=True,\n",
    "    warmup_steps = 100,\n",
    "    gradient_accumulation_steps = 8,    \n",
    ")\n",
    "\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=tokenized_qa[\"train\"],\n",
    "    eval_dataset=tokenized_qa[\"validation\"],\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),\n",
    "    compute_metrics=compute_metrics\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Push to hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.push_to_hub()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "469bf39fb786ce62fe1cf28daf5532353697c7f878bc2a141a4af79f26162277"
  },
  "kernelspec": {
   "display_name": "Python 3.10.4 ('medqa')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
